[JAX] Fix BF16 tolerance for CGEMM + RS + BF16 test#2860
phu0ngng merged 3 commits into NVIDIA:main
Conversation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
/te-ci JAX L0
Greptile Summary

This PR increases the BF16 comparison tolerance (`atol`) for the CGEMM + RS + BF16 test path.

Confidence Score: 5/5. Safe to merge: the change is a well-justified tolerance relaxation for a single test path, backed by a mathematical derivation and a standalone reproducer. Only one file changes, the logic is correct (the condition precisely identifies RS + BF16 without quantization), the new `atol=0.125` is derived from first principles (2× the worst-case 1-ULP BF16 difference at O(8) scale), and `assert_allclose` correctly handles both the both-provided and both-`None` paths. No functional code is modified. No files require special attention.
|
| Filename | Overview |
|---|---|
| examples/jax/collective_gemm/test_gemm.py | Adds a targeted tolerance override (rtol=1e-2, atol=0.125) only for the CGEMM+RS+BF16 path, with a detailed comment explaining the reduction-order mismatch; all other paths continue to use dtype-default tolerances. |
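The first-principles derivation quoted in the summary can be checked with a few lines of ULP arithmetic. This is an illustrative sketch (the `bf16_ulp` helper is not from the PR): BF16 stores 7 explicit mantissa bits, so one ULP at magnitude `x` is `2**(floor(log2 x) - 7)`.

```python
import math

def bf16_ulp(x):
    # Spacing between adjacent BF16 values at magnitude |x|.
    # BF16 has 8 significand bits (7 stored + 1 implicit), so the ULP
    # within the binade of x is 2**(exponent - 7).
    return 2.0 ** (math.floor(math.log2(abs(x))) - 7)

one_ulp_at_8 = bf16_ulp(8.0)   # worst-case 1-ULP diff at O(8) scale: 0.0625
atol = 2 * one_ulp_at_8        # 2x safety margin gives the chosen 0.125
crossover = atol / 1e-2        # 12.5: above this |ref|, rtol=1e-2 dominates
```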
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[run_gemm_tests] --> B{enable_result_check\nand process_id == 0?}
B -- No --> Z[Skip check]
B -- Yes --> C{collective_op == REDUCE_SCATTER\nand not use_quantization?}
C -- Yes\nis_cgemm_rs_bf16=True --> D["rtol = 1e-2\natol = 0.125\n(covers 1 BF16 ULP near-zero)"]
C -- No\nis_cgemm_rs_bf16=False --> E["rtol = None\natol = None\n(use dtype defaults)"]
D --> F[assert_allclose\ngathered_ref_output vs gathered_output]
E --> F
F --> G{Both rtol and atol\nnot None?}
G -- Yes --> H["Use provided\nrtol=1e-2, atol=0.125"]
G -- No --> I["Fall back to\ndtype_tols(bfloat16)\nrtol=1e-2, atol=1e-5"]
H --> J[np.testing.assert_allclose]
I --> J
```
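The tolerance-selection branch in the flowchart can be sketched as follows. This is a hypothetical stand-in, not the actual helper in `examples/jax/collective_gemm/test_gemm.py`: the caller's tolerances are used only when both `rtol` and `atol` are provided, otherwise the BF16 dtype defaults apply.

```python
import numpy as np

def check_outputs(ref, out, rtol=None, atol=None):
    # Illustrative stand-in for the test's assert_allclose wrapper:
    # use the caller's tolerances only when BOTH are given, otherwise
    # fall back to the BF16 dtype defaults quoted in the flowchart.
    if rtol is None or atol is None:
        rtol, atol = 1e-2, 1e-5  # dtype_tols(bfloat16) defaults
    np.testing.assert_allclose(out, ref, rtol=rtol, atol=atol)

# CGEMM + RS + BF16 path would pass the relaxed pair, e.g.:
# check_outputs(gathered_ref_output, gathered_output, rtol=1e-2, atol=0.125)
```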
Reviews (2): Last reviewed commit: "Merge branch 'main' into cgemm_bf16_fix_..."
Description
`atol=1e-5` was too strict for BF16 comparisons between the NONE collective GEMM and collective GEMM with RS collective paths. Both paths split K across TP ranks and produce identical BF16 partial GEMMs, but reduce them in different orders:

- `((p0+p1)+(p2+p3))`: binary tree in FP32 → BF16
- `((p0+p1)+p2)+p3`: sequential in FP32 → BF16

Different reduction associativity causes rounding differences of up to 1 BF16 ULP of the partial GEMM magnitude. The combined tolerance `atol + rtol*|ref|` covers this across all output scales:

- Large magnitudes (`|ref| > atol/rtol = 12.5`): `rtol=1e-2` dominates and provides sufficient coverage.
- Near zero: `rtol` provides no coverage, so `atol=0.125` (2× the worst-case 1-ULP diff at O(8) scale) is needed.
- The old `atol=1e-5` failed because it is far below 1 ULP at any realistic activation magnitude.

Reproducer
The mismatch is verified by a standalone test (https://gist.github.com/phu0ngng/9600caf76df6040ecc4b3f3c6ea20882) that mimics the two collective paths on a single GPU:
- Same shapes and seed as `test_gemm.py` (M=8192, K_tp=1024, N=16384, seed=PRNGKey(0)).
- Compares the reference binary-tree order (`C_none`) against the TE sequential order (`C_rs`).
- 2 elements differ by exactly 1 BF16 ULP.
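The reduction-order mismatch can also be illustrated without a GPU. The sketch below (all names and shapes are illustrative, not the PR's reproducer) uses a hand-rolled round-to-nearest-even BF16 cast, since NumPy has no native bfloat16 dtype, and checks that the two reduction orders stay within the combined tolerance.

```python
import numpy as np

def to_bf16(x):
    # Round FP32 to BF16 (round-to-nearest-even), stored back in an FP32
    # container, since NumPy lacks a native bfloat16 dtype.
    u = np.asarray(x, dtype=np.float32).view(np.uint32)
    bias = np.uint32(0x7FFF) + ((u >> np.uint32(16)) & np.uint32(1))
    return ((u + bias) & np.uint32(0xFFFF0000)).view(np.float32)

# Four BF16 partial-GEMM results per output element, one per TP rank.
# A scale mixture spreads the exponents so FP32 sums become inexact and
# associativity starts to matter (illustrative, not the PR's data).
rng = np.random.default_rng(0)
scales = 2.0 ** rng.integers(-12, 4, size=(4, 4096))
p = to_bf16(rng.normal(0.0, 1.0, size=(4, 4096)) * scales)

tree = to_bf16((p[0] + p[1]) + (p[2] + p[3]))  # NONE path: binary tree in FP32
seq = to_bf16(((p[0] + p[1]) + p[2]) + p[3])   # RS path: sequential in FP32

# Any mismatches are at most ~1 BF16 ULP of the result magnitude,
# which atol=0.125 + rtol=1e-2 covers at every output scale.
max_diff = np.abs(tree - seq).max()
```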
Type of change
Changes
- Increase `atol` from `1e-5` to `0.125` to cover the near-zero regime where `rtol` provides no coverage. Large-magnitude diffs (the common case) are already handled by `rtol=1e-2`.

Checklist: